"""
create train or eval dataset.
"""
import os
from functools import partial
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from mindspore.communication.management import init, get_rank, get_group_size
from src.config import config_quant

# Alias for the quantization configuration imported from src.config;
# create_dataset reads the target image height/width from it.
config = config_quant

def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval dataset

    The device count and the rank of the current process are read from the
    ``DEVICE_NUM`` and ``RANK_ID`` environment variables. When a variable is
    not set, a single-device setup (``DEVICE_NUM=1``, ``RANK_ID=0``) is assumed.
    With more than one device the dataset is sharded so that each rank reads
    its own shard.

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval. Random crop
            and random horizontal flip are only applied when this is True.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend. Currently not
            consulted by this function; kept for interface compatibility.

    Returns:
        dataset

    Raises:
        ValueError: if ``DEVICE_NUM`` or ``RANK_ID`` is set to a non-integer value.
    """
    # Default to a single-device run when the launcher did not export the variables,
    # instead of failing with int(None).
    device_num = int(os.getenv("DEVICE_NUM", "1"))
    rank_id = int(os.getenv("RANK_ID", "0"))

    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        # Shard the dataset across devices: num_shards is the number of shards,
        # shard_id selects the shard read by this process.
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                               num_shards=device_num, shard_id=rank_id)

    resize_height = config.image_height
    resize_width = config.image_width
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))
    # Use the same flip probability on every rank; deriving it from rank_id
    # would disable the augmentation entirely on rank 0.
    random_horizontal_flip_op = C.RandomHorizontalFlip(0.5)

    resize_op = C.Resize((resize_height, resize_width))
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])

    change_swap_op = C.HWC2CHW()

    # Random augmentations come first and only for training; the remaining
    # operations are shared by train and eval.
    trans = []
    if do_train:
        trans += [random_crop_op, random_horizontal_flip_op]
    trans += [resize_op, rescale_op, normalize_op, change_swap_op]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_num)
    return ds